package z1add_field

import (
	"fmt"
	"reflect"
	"sort"

	"gitee.com/z1gotool/z1err"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// https://www.cnblogs.com/xinliangcoder/p/14316509.html
// https://github.com/xinliangnote/go-gin-api/blob/master/internal/repository/mysql/plugin.go
// https://kgithub.com/xinliangnote/go-gin-api/blob/master/internal/repository/mysql/plugin.go
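// Usage: hand the extra columns to the callbacks through db.WithContext(ctx).
// The callbacks read a map[string]interface{} from the statement context:
// under the key `z1AddFieldInsert` it fills zero-valued fields of INSERTed
// records; under `z1AddFieldWhere` each entry is ANDed onto the WHERE clause
// of queries, updates and deletes, e.g. map[string]interface{}{`place_id`: 735}.
// db.WithContext returns a new *gorm.DB rather than modifying db, so the
// result must be kept (db = db.WithContext(ctx)) or the context is lost:
// db = db.WithContext(context.WithValue(
//
//	context.Background(),
//	`z1AddFieldWhere`,
//	map[string]interface{}{
//		`place_id`: 735,
//	},
//
// ))
// The callbacks read it back with db.Statement.Context.Value(`z1AddFieldWhere`).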

// callBackBeforeNamAddField is the name under which this plugin registers its
// "before" callback on the Create, Query, Delete and Update processors.
// (The identifier lacks the "e" of "Name"; it is kept as-is for compatibility.)
const callBackBeforeNamAddField = "z1_add_field:before"

// callBackAfterNameAddField is reserved for an "after" callback; nothing in
// this file registers or references it.
const callBackAfterNameAddField = "z1_add_field:after"

// Z1AddField is a GORM plugin that copies column values carried in the
// statement context into INSERTed records and into the WHERE clause of
// queries, updates and deletes. It holds no state.
type Z1AddField struct{}

// Name implements gorm.Plugin; GORM keys registered plugins by this value.
func (plugin *Z1AddField) Name() string {
	const pluginName = "Z1AddField"
	return pluginName
}

// Initialize implements gorm.Plugin by registering a callback named
// callBackBeforeNamAddField on each of db's processors:
//
//   - Create (before "gorm:before_create"): a map stored in the statement
//     context under `z1AddFieldInsert` is passed to ModifyStatementCreate,
//     which fills zero-valued fields of the record(s) being inserted.
//   - Query (before "gorm:query"), Delete (before "gorm:before_delete") and
//     Update (before "gorm:setup_reflect_value"): a map stored under
//     `z1AddFieldWhere` is passed to ModifyStatementWhere, which ANDs one
//     `column = value` condition per entry onto the WHERE clause.
//
// Changes are applied only when goAddField approves. Both context values
// must be map[string]interface{}; a non-nil value of any other type is
// recorded with db.AddError, so the operation returns an error instead of
// panicking or running without the requested conditions.
//
// An error returned by GORM while registering a callback is returned
// (wrapped), so db.Use reports it.
func (plugin *Z1AddField) Initialize(db *gorm.DB) (err error) {
	// addFieldFromContext returns the map stored under key in tx's statement
	// context; ok is false when there is nothing valid to apply.
	addFieldFromContext := func(tx *gorm.DB, key string) (addField map[string]interface{}, ok bool) {
		ctx := tx.Statement.Context
		if ctx == nil {
			return nil, false
		}
		raw := ctx.Value(key)
		if raw == nil {
			return nil, false
		}
		if addField, ok = raw.(map[string]interface{}); !ok {
			// Fail the operation rather than silently dropping the extra
			// conditions, which could widen a query beyond what the caller meant.
			_ = tx.AddError(fmt.Errorf("z1_add_field: context value %q has type %T, want map[string]interface{}", key, raw))
		}
		return addField, ok
	}

	z1beforeCreate := func(tx *gorm.DB) {
		if addField, ok := addFieldFromContext(tx, `z1AddFieldInsert`); ok && goAddField(plugin, tx) {
			ModifyStatementCreate(tx, addField)
		}
	}
	z1before := func(tx *gorm.DB) {
		if addField, ok := addFieldFromContext(tx, `z1AddFieldWhere`); ok && goAddField(plugin, tx) {
			ModifyStatementWhere(tx, addField)
		}
	}

	cb := db.Callback()
	if err = cb.Create().Before("gorm:before_create").Register(callBackBeforeNamAddField, z1beforeCreate); err != nil {
		return fmt.Errorf("z1_add_field: register create callback: %w", err)
	}
	if err = cb.Query().Before("gorm:query").Register(callBackBeforeNamAddField, z1before); err != nil {
		return fmt.Errorf("z1_add_field: register query callback: %w", err)
	}
	if err = cb.Delete().Before("gorm:before_delete").Register(callBackBeforeNamAddField, z1before); err != nil {
		return fmt.Errorf("z1_add_field: register delete callback: %w", err)
	}
	if err = cb.Update().Before("gorm:setup_reflect_value").Register(callBackBeforeNamAddField, z1before); err != nil {
		return fmt.Errorf("z1_add_field: register update callback: %w", err)
	}

	return nil
}

// ModifyStatementWhere ANDs one condition `<current table>.<key> = value`
// onto db's WHERE clause for every entry in z1AddField.
//
// Conditions are appended in sorted key order, so the generated SQL text is
// the same on every run regardless of Go's random map iteration order (stable
// logs, and one cached statement per shape when PrepareStmt is enabled).
// A nil or empty map leaves the statement untouched.
//
// The OR-grouping step is adapted from gorm.io/plugin/soft_delete: when the
// existing WHERE holds several expressions and one of them is a single-element
// clause.OrConditions (as produced by db.Or), all existing expressions are
// first wrapped in a single clause.And group, so the new conditions apply to
// the whole existing predicate rather than only its last OR operand.
func ModifyStatementWhere(db *gorm.DB, z1AddField map[string]interface{}) {
	if len(z1AddField) == 0 {
		return
	}
	stmt := db.Statement

	if c, ok := stmt.Clauses["WHERE"]; ok {
		if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) > 1 {
			for _, expr := range where.Exprs {
				if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 {
					where.Exprs = []clause.Expression{clause.And(where.Exprs...)}
					c.Expression = where
					stmt.Clauses["WHERE"] = c
					break
				}
			}
		}
	}

	keys := make([]string, 0, len(z1AddField))
	for key := range z1AddField {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	exprs := make([]clause.Expression, 0, len(keys))
	for _, key := range keys {
		exprs = append(exprs, clause.Eq{
			Column: clause.Column{Table: clause.CurrentTable, Name: key},
			Value:  z1AddField[key],
		})
	}
	// clause.Where merges by appending, so one AddClause with all conditions
	// is equivalent to one AddClause per condition.
	stmt.AddClause(clause.Where{Exprs: exprs})
}

// ModifyStatementCreate fills fields of the record(s) being created with the
// values in z1AddField. Each key is resolved with db.Statement.Schema.LookUpField;
// keys that match no field are ignored. A field is set only while it still
// holds its zero value, so values the caller supplied explicitly are kept.
//
// It handles a destination that is a single struct or a slice/array of
// records; any other kind, or a statement without a parsed schema, is left
// untouched. An error from field.Set is passed to z1err.Check.
// NOTE(review): z1err.Check presumably panics on error; inside a GORM callback
// db.AddError would be gentler — confirm the intended behavior.
//
// See https://gorm.io/docs/write_plugins.html
func ModifyStatementCreate(db *gorm.DB, z1AddField map[string]interface{}) {
	stmt := db.Statement
	if stmt.Schema == nil || len(z1AddField) == 0 {
		return
	}

	// Resolve every key against the schema once, not once per row, keeping a
	// filler for each key that maps to a real field.
	fillers := make([]func(row reflect.Value), 0, len(z1AddField))
	for key, value := range z1AddField {
		field := stmt.Schema.LookUpField(key)
		if field == nil {
			continue
		}
		value := value // per-iteration copy for the closure (pre-Go 1.22 loop semantics)
		fillers = append(fillers, func(row reflect.Value) {
			if _, isZero := field.ValueOf(row); isZero {
				err := field.Set(row, value)
				z1err.Check(err)
			}
		})
	}
	if len(fillers) == 0 {
		return
	}

	// fillRow applies every filler to one record.
	fillRow := func(row reflect.Value) {
		for _, fill := range fillers {
			fill(row)
		}
	}

	switch rv := stmt.ReflectValue; rv.Kind() {
	case reflect.Slice, reflect.Array:
		for i := 0; i < rv.Len(); i++ {
			fillRow(rv.Index(i))
		}
	case reflect.Struct:
		fillRow(rv)
	}
}

// goAddField is the hook deciding whether the plugin may modify the statement
// carried by db. It currently approves every statement unconditionally;
// neither argument is consulted.
func goAddField(plugin *Z1AddField, db *gorm.DB) (yes bool) {
	yes = true
	return yes
}
